{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
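    "# Music Streaming User Churn Prediction -- Logistic Regression, Test Set\n",
    "\n",
    "### Dataset Description\n",
    "\n",
    "The project provides KKBOX user-song repeated listening records, together with user and song metadata. The training data consists of users whose subscriptions expired in February 2017, and the target label indicates whether a user renewed the service in March 2017. The test set consists of users whose subscriptions expire in March 2017; the task is to predict the probability that each user renews (rather than churns) within one month after expiration, i.e. in April 2017.\n",
    "\n",
    "The files and their fields are described below:\n",
    "\n",
    "1. train.csv: training data, 7,377,418 records in total\n",
    "\n",
    "    msno: user id, an encrypted string\n",
    "\n",
    "    song_id: song id\n",
    "\n",
    "    source_system_tab: the tab where the event was triggered, indicating which app feature was in use\n",
    "\n",
    "    source_screen_name: name of the layout the user sees\n",
    "\n",
    "    source_type: the entry point from which the user played the song in the app\n",
    "\n",
    "    target: the label. 1 means the user keeps subscribing within one month after the first listening event, 0 means the user does not.\n",
    "\n",
    "2. test.csv: test data, 2,556,790 records in total\n",
    "\n",
    "    id: row id (used for the submission)\n",
    "\n",
    "    msno: user id\n",
    "\n",
    "    song_id: song id\n",
    "\n",
    "    source_system_tab: the tab where the event was triggered, indicating which app feature was in use\n",
    "\n",
    "    source_screen_name: name of the layout the user sees\n",
    "\n",
    "    source_type: the entry point from which the user played the song in the app\n",
    "\n",
    "3. sampleSubmission.csv: sample submission file\n",
    "\n",
    "    A submission contains two fields, the test sample id and the probability that its label is 1, in the following format:\n",
    "\n",
    "    id,target\n",
    "    \n",
    "    2,0.3\n",
    "    \n",
    "    5,0.1\n",
    "    \n",
    "    6,1\n",
    "    \n",
    "    etc.\n",
    "\n",
    "4. songs.csv: song metadata, Unicode-encoded\n",
    "\n",
    "    song_id: song id\n",
    "\n",
    "    song_length: song length in milliseconds\n",
    "\n",
    "    genre_ids: genre categories; a song may have several, separated by `|`\n",
    "\n",
    "    artist_name: artist\n",
    "\n",
    "    composer: composer\n",
    "\n",
    "    lyricist: lyricist\n",
    "\n",
    "    language: language\n",
    "\n",
    "5. members.csv: user metadata\n",
    "\n",
    "    msno: user id\n",
    "\n",
    "    city: city\n",
    "\n",
    "    bd: age. Note: the age data contains outliers\n",
    "\n",
    "    gender: gender\n",
    "\n",
    "    registered_via: registration method\n",
    "\n",
    "    registration_init_time: registration date, format %Y%m%d\n",
    "\n",
    "    expiration_date: expiration date, format %Y%m%d\n",
    "\n",
    "6. song_extra_infos.csv: additional song information\n",
    "\n",
    "    song_id: song id\n",
    "\n",
    "    song name: name of the song\n",
    "\n",
    "    isrc: International Standard Recording Code. In theory it could serve as a song id, but the ISRCs here were generated without official authorization, so the information encoded in them, such as the country code and reference year, may be incorrect. Several songs may also share one ISRC, because a single recording can be released more than once.\n",
    "    \n",
    "**Since the training set is large and the feature dimensionality is high, we first try a Logistic Regression model to see how well it performs.**"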
   ]
  },
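  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The field notes above state that `genre_ids` may contain several genres separated by `|`, and that the date fields use the `%Y%m%d` format. The cell below is a minimal sketch, assuming pandas, of how such columns could be parsed. The toy rows and the `membership_days` feature are hypothetical illustrations; they are not taken from the real files or from the preprocessing that produced `Merge_Test_Scaler.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy rows mimicking songs.csv and members.csv (not real data)\n",
    "songs = pd.DataFrame({'song_id': ['s1', 's2', 's3'],\n",
    "                      'genre_ids': ['465', '465|958', None]})\n",
    "members = pd.DataFrame({'msno': ['u1', 'u2'],\n",
    "                        'registration_init_time': [20110911, 20150628],\n",
    "                        'expiration_date': [20170920, 20170622]})\n",
    "\n",
    "# genre_ids can hold several genres separated by '|': one indicator column per genre;\n",
    "# a missing value becomes a row of zeros\n",
    "genre_dummies = songs['genre_ids'].str.get_dummies(sep='|').add_prefix('genre_')\n",
    "songs = pd.concat([songs, genre_dummies], axis=1)\n",
    "\n",
    "# Date fields are stored as integers in %Y%m%d format; parse them into datetimes\n",
    "for col in ['registration_init_time', 'expiration_date']:\n",
    "    members[col] = pd.to_datetime(members[col].astype(str), format='%Y%m%d')\n",
    "\n",
    "# Hypothetical derived feature (an assumption): membership length in days\n",
    "members['membership_days'] = (members['expiration_date'] - members['registration_init_time']).dt.days\n",
    "print(genre_dummies)\n",
    "print(members)"
   ]
  },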
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle as cPickle  # Python 3: cPickle was folded into the standard pickle module\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "dpath = '../data/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>msno_label</th>\n",
       "      <th>song_id_label</th>\n",
       "      <th>source_system_tab_missing</th>\n",
       "      <th>source_screen_name_missing</th>\n",
       "      <th>source_type_missing</th>\n",
       "      <th>city_1</th>\n",
       "      <th>city_3</th>\n",
       "      <th>city_4</th>\n",
       "      <th>city_5</th>\n",
       "      <th>...</th>\n",
       "      <th>language_-1.0</th>\n",
       "      <th>language_3.0</th>\n",
       "      <th>language_10.0</th>\n",
       "      <th>language_17.0</th>\n",
       "      <th>language_24.0</th>\n",
       "      <th>language_31.0</th>\n",
       "      <th>language_38.0</th>\n",
       "      <th>language_45.0</th>\n",
       "      <th>language_52.0</th>\n",
       "      <th>language_59.0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.420563</td>\n",
       "      <td>0.339452</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.420563</td>\n",
       "      <td>0.605356</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.023151</td>\n",
       "      <td>0.103857</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.044970</td>\n",
       "      <td>0.623283</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.044970</td>\n",
       "      <td>0.237793</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 250 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   id  msno_label  song_id_label  source_system_tab_missing  \\\n",
       "0   0    0.420563       0.339452                        0.0   \n",
       "1   1    0.420563       0.605356                        0.0   \n",
       "2   2    0.023151       0.103857                        0.0   \n",
       "3   3    0.044970       0.623283                        0.0   \n",
       "4   4    0.044970       0.237793                        0.0   \n",
       "\n",
       "   source_screen_name_missing  source_type_missing  city_1  city_3  city_4  \\\n",
       "0                         0.0                  0.0     1.0     0.0     0.0   \n",
       "1                         0.0                  0.0     1.0     0.0     0.0   \n",
       "2                         1.0                  0.0     1.0     0.0     0.0   \n",
       "3                         0.0                  0.0     0.0     1.0     0.0   \n",
       "4                         0.0                  0.0     0.0     1.0     0.0   \n",
       "\n",
       "   city_5  ...  language_-1.0  language_3.0  language_10.0  language_17.0  \\\n",
       "0     0.0  ...            0.0           1.0            0.0            0.0   \n",
       "1     0.0  ...            0.0           1.0            0.0            0.0   \n",
       "2     0.0  ...            0.0           0.0            0.0            1.0   \n",
       "3     0.0  ...            0.0           0.0            0.0            0.0   \n",
       "4     0.0  ...            1.0           0.0            0.0            0.0   \n",
       "\n",
       "   language_24.0  language_31.0  language_38.0  language_45.0  language_52.0  \\\n",
       "0            0.0            0.0            0.0            0.0            0.0   \n",
       "1            0.0            0.0            0.0            0.0            0.0   \n",
       "2            0.0            0.0            0.0            0.0            0.0   \n",
       "3            0.0            0.0            0.0            0.0            1.0   \n",
       "4            0.0            0.0            0.0            0.0            0.0   \n",
       "\n",
       "   language_59.0  \n",
       "0            0.0  \n",
       "1            0.0  \n",
       "2            0.0  \n",
       "3            0.0  \n",
       "4            0.0  \n",
       "\n",
       "[5 rows x 250 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test = pd.read_csv(dpath + 'LR_data/Merge_Test_Scaler.csv')\n",
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
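    "# Drop the msno_label and song_id_label columns; the remaining columns must match the features the loaded model was trained on\n",
    "test = test.drop(['msno_label','song_id_label'], axis=1)"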
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
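    "# Keep the id column aside, for saving the transformed results and for the submission file\n",
    "test_id = test['id']   \n",
    "X_test = test.drop([\"id\"], axis=1)\n",
    "\n",
    "# Save the feature names for later use (visualization)\n",
    "feat_names = X_test.columns \n",
    "\n",
    "# Convert to a sparse CSR matrix: most one-hot columns are 0, so sparse input is much faster for the model\n",
    "from scipy.sparse import csr_matrix\n",
    "X_test = csr_matrix(X_test)"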
   ]
  },
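  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Converting `X_test` to CSR pays off because the scaled one-hot features are mostly zeros. The cell below is a small sketch on synthetic data (the 1000 x 250 shape and about 5 active columns per row are assumptions, chosen only to resemble the one-hot layout above). It compares the memory held by a dense array with that of its CSR form, which stores only the non-zero values, their column indices and the row pointers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "# Synthetic one-hot-like matrix (assumption): 1000 rows x 250 columns, about 5 ones per row\n",
    "rng = np.random.default_rng(0)\n",
    "dense = np.zeros((1000, 250))\n",
    "cols = rng.integers(0, 250, size=(1000, 5))\n",
    "dense[np.arange(1000)[:, None], cols] = 1.0\n",
    "\n",
    "sparse = csr_matrix(dense)\n",
    "\n",
    "# CSR keeps only the non-zero values, their column indices and (rows + 1) row pointers\n",
    "dense_bytes = dense.nbytes\n",
    "sparse_bytes = sparse.data.nbytes + sparse.indices.nbytes + sparse.indptr.nbytes\n",
    "print('non-zeros:', sparse.nnz)\n",
    "print('dense bytes:', dense_bytes, 'CSR bytes:', sparse_bytes)"
   ]
  },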
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the trained model and predict on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
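    "# Load the trained logistic regression model (saga solver, per the file name); the with-block closes the file\n",
    "with open(dpath + \"LR_data/kkbox_LR_Saga.pkl\", 'rb') as f:\n",
    "    lr_saga_best = cPickle.load(f)\n",
    "# Predicted probability of each class; columns follow lr_saga_best.classes_, so column 1 is P(target = 1)\n",
    "y_test_pred = lr_saga_best.predict_proba(X_test)"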
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2556788, 2)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test_pred.shape"
   ]
  },
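  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assuming the pickled model is a scikit-learn `LogisticRegression` trained with the `saga` solver (the file name suggests this, but the training notebook is not shown here), `predict_proba` returns one column per class in the order of `classes_`. For a binary model, column 1 equals the logistic (sigmoid) function applied to `decision_function`, i.e. P(target = 1). The cell below checks this relationship on hypothetical toy data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Hypothetical toy data, only to relate predict_proba to the linear score\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(200, 3))\n",
    "y = (X[:, 0] - 0.5 * X[:, 1] + rng.normal(scale=0.5, size=200) > 0).astype(int)\n",
    "\n",
    "clf = LogisticRegression(solver='saga', max_iter=5000).fit(X, y)\n",
    "\n",
    "proba = clf.predict_proba(X)       # shape (200, 2), columns ordered as clf.classes_\n",
    "z = clf.decision_function(X)       # linear score w . x + b per sample\n",
    "p1 = 1.0 / (1.0 + np.exp(-z))      # logistic (sigmoid) function\n",
    "\n",
    "print(clf.classes_)                  # [0 1]\n",
    "print(np.allclose(proba[:, 1], p1))  # True"
   ]
  },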
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate the submission file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
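    "out_df = pd.DataFrame(y_test_pred)\n",
    "out_df = pd.concat([test_id, out_df], axis = 1)\n",
    "\n",
    "# The submission format is id,target, where target is P(label = 1), i.e. column 1 of y_test_pred\n",
    "submission = pd.DataFrame({'id': test_id, 'target': y_test_pred[:, 1]})\n",
    "submission.to_csv(dpath + \"LR_data/kkbox_LR_Result.csv\", index=False)"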
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.439549</td>\n",
       "      <td>0.560451</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.470552</td>\n",
       "      <td>0.529448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.824987</td>\n",
       "      <td>0.175013</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.520828</td>\n",
       "      <td>0.479172</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.575428</td>\n",
       "      <td>0.424572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>5</td>\n",
       "      <td>0.462753</td>\n",
       "      <td>0.537247</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>6</td>\n",
       "      <td>0.452485</td>\n",
       "      <td>0.547515</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>7</td>\n",
       "      <td>0.452481</td>\n",
       "      <td>0.547519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>8</td>\n",
       "      <td>0.481217</td>\n",
       "      <td>0.518783</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>9</td>\n",
       "      <td>0.438436</td>\n",
       "      <td>0.561564</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>10</td>\n",
       "      <td>0.471569</td>\n",
       "      <td>0.528431</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>11</td>\n",
       "      <td>0.620833</td>\n",
       "      <td>0.379167</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>12</td>\n",
       "      <td>0.467774</td>\n",
       "      <td>0.532226</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>13</td>\n",
       "      <td>0.536455</td>\n",
       "      <td>0.463545</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>14</td>\n",
       "      <td>0.510725</td>\n",
       "      <td>0.489275</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>15</td>\n",
       "      <td>0.492768</td>\n",
       "      <td>0.507232</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>16</td>\n",
       "      <td>0.537207</td>\n",
       "      <td>0.462793</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>17</td>\n",
       "      <td>0.441161</td>\n",
       "      <td>0.558839</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>18</td>\n",
       "      <td>0.460471</td>\n",
       "      <td>0.539529</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>19</td>\n",
       "      <td>0.446975</td>\n",
       "      <td>0.553025</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>20</td>\n",
       "      <td>0.446023</td>\n",
       "      <td>0.553977</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>21</td>\n",
       "      <td>0.446205</td>\n",
       "      <td>0.553795</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>22</td>\n",
       "      <td>0.475203</td>\n",
       "      <td>0.524797</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>23</td>\n",
       "      <td>0.517157</td>\n",
       "      <td>0.482843</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>24</td>\n",
       "      <td>0.436550</td>\n",
       "      <td>0.563450</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>25</td>\n",
       "      <td>0.455192</td>\n",
       "      <td>0.544808</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>26</td>\n",
       "      <td>0.478275</td>\n",
       "      <td>0.521725</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>27</td>\n",
       "      <td>0.455212</td>\n",
       "      <td>0.544788</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>28</td>\n",
       "      <td>0.453723</td>\n",
       "      <td>0.546277</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>29</td>\n",
       "      <td>0.557706</td>\n",
       "      <td>0.442294</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    id         0         1\n",
       "0    0  0.439549  0.560451\n",
       "1    1  0.470552  0.529448\n",
       "2    2  0.824987  0.175013\n",
       "3    3  0.520828  0.479172\n",
       "4    4  0.575428  0.424572\n",
       "5    5  0.462753  0.537247\n",
       "6    6  0.452485  0.547515\n",
       "7    7  0.452481  0.547519\n",
       "8    8  0.481217  0.518783\n",
       "9    9  0.438436  0.561564\n",
       "10  10  0.471569  0.528431\n",
       "11  11  0.620833  0.379167\n",
       "12  12  0.467774  0.532226\n",
       "13  13  0.536455  0.463545\n",
       "14  14  0.510725  0.489275\n",
       "15  15  0.492768  0.507232\n",
       "16  16  0.537207  0.462793\n",
       "17  17  0.441161  0.558839\n",
       "18  18  0.460471  0.539529\n",
       "19  19  0.446975  0.553025\n",
       "20  20  0.446023  0.553977\n",
       "21  21  0.446205  0.553795\n",
       "22  22  0.475203  0.524797\n",
       "23  23  0.517157  0.482843\n",
       "24  24  0.436550  0.563450\n",
       "25  25  0.455192  0.544808\n",
       "26  26  0.478275  0.521725\n",
       "27  27  0.455212  0.544788\n",
       "28  28  0.453723  0.546277\n",
       "29  29  0.557706  0.442294"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_df.head(30)"
   ]
  },
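  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sample submission expects exactly two columns, `id` and `target`, where `target` is the probability of label 1. The cell below sketches a small validation helper before writing the CSV; `make_submission` is a hypothetical name introduced here, and the three toy rows only stand in for `test_id` and `y_test_pred`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def make_submission(ids, proba):\n",
    "    '''Hypothetical helper: build an id,target frame from predict_proba output.'''\n",
    "    proba = np.asarray(proba)\n",
    "    if proba.ndim != 2 or proba.shape[1] != 2:\n",
    "        raise ValueError('expected an (n_samples, 2) array from predict_proba')\n",
    "    if len(ids) != proba.shape[0]:\n",
    "        raise ValueError('number of ids and predictions differ')\n",
    "    # Column 1 holds P(target = 1), which is what the submission asks for\n",
    "    sub = pd.DataFrame({'id': np.asarray(ids), 'target': proba[:, 1]})\n",
    "    if not sub['target'].between(0, 1).all():\n",
    "        raise ValueError('target values must be probabilities in [0, 1]')\n",
    "    return sub\n",
    "\n",
    "# Toy values; with the real data this would be make_submission(test_id, y_test_pred)\n",
    "sub = make_submission([0, 1, 2], [[0.44, 0.56], [0.47, 0.53], [0.82, 0.18]])\n",
    "print(sub.to_csv(index=False))"
   ]
  },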
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
